library(rpart)
library(rattle)
## Warning: package 'rattle' was built under R version 4.2.2
## Loading required package: tibble
## Loading required package: bitops
## Rattle: A free graphical interface for data science with R.
## Versión 5.5.1 Copyright (c) 2006-2021 Togaware Pty Ltd.
## Escriba 'rattle()' para agitar, sacudir y  rotar sus datos.
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✔ ggplot2 3.3.6     ✔ dplyr   1.0.9
## ✔ tidyr   1.2.0     ✔ stringr 1.4.0
## ✔ readr   2.1.2     ✔ forcats 0.5.1
## ✔ purrr   0.3.4
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(plotly)
## 
## Attaching package: 'plotly'
## The following object is masked from 'package:ggplot2':
## 
##     last_plot
## The following object is masked from 'package:stats':
## 
##     filter
## The following object is masked from 'package:graphics':
## 
##     layout
# Build a synthetic dataset (it is plotted in 3D further down).
set.seed(911)

n <- 1000

# Draw order matters for reproducibility under the seed above:
# x1 first, then x2, then the noise term.
dtrain <- data.frame(
  x1 = runif(n, min = 4.5, max = 13.5),
  x2 = runif(n, min = 4.5, max = 13.5)
)
noise <- rnorm(n, mean = 0, sd = 0.5)

# Response: Euclidean distance of (x1, x2) from the point (9, 9), plus noise.
dtrain$y <- sqrt((dtrain$x1 - 9)^2 + (dtrain$x2 - 9)^2) + noise


# Visualize the synthetic dataset.
# First a 3D scatter of (x1, x2, y) ...
add_markers(
  plot_ly(dtrain, x = ~x1, y = ~x2, z = ~y),
  size = 1,
  color = I("orange")
)
# ... then a 2D scatter of the predictors only.
add_markers(
  plot_ly(dtrain, x = ~x1, y = ~x2),
  size = 1,
  color = I("orange")
)

# $ Y = f(x) + \epsilon = \sqrt{(x_1 - 9)^2 + (x_2 - 9)^2} + \epsilon $

# Fit a regression tree (method = "anova") of y on x1 and x2.
# Depth is capped at 3; minsplit/minbucket of 1 and cp = 0 remove the other
# stopping criteria, so the depth cap is what limits the tree.
tree <- rpart(
  y ~ x1 + x2,
  data = dtrain,
  method = "anova",
  control = rpart.control(maxdepth = 3, minsplit = 1, minbucket = 1, cp = 0)
)
fancyRpartPlot(tree)

# Tree prediction for every training row.
fitted.values <- predict(tree, dtrain)

# Terminal nodes are the rows of the tree's frame whose `var` is "<leaf>";
# the frame's row names hold the node numbers.
frame <- tree$frame
nodevec <- as.numeric(rownames(frame)[frame$var == "<leaf>"])

# One list element per terminal node, holding the split conditions on the way
# from the root to that node (path.rpart also prints them).
path.list <- path.rpart(tree, nodes = nodevec)
## 
##  node number: 8 
##    root
##    x2>=6.26
##    x2< 11.41
##    x1< 11.73
## 
##  node number: 9 
##    root
##    x2>=6.26
##    x2< 11.41
##    x1>=11.73
## 
##  node number: 10 
##    root
##    x2>=6.26
##    x2>=11.41
##    x2< 12.49
## 
##  node number: 11 
##    root
##    x2>=6.26
##    x2>=11.41
##    x2>=12.49
## 
##  node number: 12 
##    root
##    x2< 6.26
##    x1>=6.345
##    x1< 11.73
## 
##  node number: 13 
##    root
##    x2< 6.26
##    x1>=6.345
##    x1>=11.73
## 
##  node number: 14 
##    root
##    x2< 6.26
##    x1< 6.345
##    x2>=5.699
## 
##  node number: 15 
##    root
##    x2< 6.26
##    x1< 6.345
##    x2< 5.699
# Convert one root-to-leaf path (character vector of split labels such as
# "x2>=6.26" or "x1< 11.73", as returned by path.rpart) into the axis-aligned
# rectangle that the leaf covers in the (x1, x2) plane.
#
# path_steps: character vector of split labels; a "root" entry is ignored.
# x1_range, x2_range: length-2 numeric vectors c(min, max) used as the
#   starting bounds before any split is applied.
# Returns a one-row data.frame with columns xmin, xmax (x1 bounds) and
# ymin, ymax (x2 bounds).
#
# Note: the cutoffs are parsed from the printed labels, so they carry whatever
# rounding path.rpart applied when formatting them.
path_to_rect <- function(path_steps, x1_range, x2_range) {
  bounds <- list(x1 = x1_range, x2 = x2_range)
  for (step in setdiff(path_steps, "root")) {
    parts <- unlist(str_split(step, "< |>="))
    split_var <- parts[1]
    cutoff <- as.numeric(parts[2])
    # Fail loudly instead of silently treating an unknown variable as x2.
    if (!split_var %in% names(bounds)) {
      stop("Unexpected split variable in path step: ", step, call. = FALSE)
    }
    if (str_detect(step, "< ")) {
      # "var< c" tightens the upper bound.
      bounds[[split_var]][2] <- cutoff
    } else {
      # "var>=c" tightens the lower bound.
      bounds[[split_var]][1] <- cutoff
    }
  }
  data.frame(
    xmin = bounds$x1[1],
    xmax = bounds$x1[2],
    ymin = bounds$x2[1],
    ymax = bounds$x2[2]
  )
}

# One rectangle per terminal node, bound together in a single rbind() rather
# than growing the data frame inside a loop. unname() keeps the default
# 1..n row names instead of the node numbers that name path.list.
rect_info <- do.call(
  rbind,
  lapply(
    unname(path.list),
    path_to_rect,
    x1_range = range(dtrain$x1),
    x2_range = range(dtrain$x2)
  )
)

# rect_info <- rect_info %>% mutate(xmed = xmin+((xmax - xmin)/2),
#                                   ymed = ymin + ((ymax - ymin)/2),
#                                   val = levels(as.factor(round(fitted.values,2))))
# 
# ggplot() +
#   geom_rect(data = rect_info,aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax),colour = "grey50", fill = "white") +
#   geom_point(data = dtrain,aes(x = x1, y = x2, color = round(fitted.values, 2))) +
#   labs(color="Valor ajustado")+
#   geom_label(rect_info, aes(x=xmed, y=ymed, label=val))+
#   theme_light()
# Attach each row's fitted value to dtrain as a factor (rounded to 2 decimals)
# so it can act as a discrete colour/label key in the plot.
dtrain <- mutate(dtrain, fitted.values = as.factor(round(fitted.values, 2)))

# One label anchor per fitted value: the median x1 and x2 of the rows that
# share it.
label_points <- summarise(
  group_by(dtrain, fitted.values),
  x = median(x1),
  y = median(x2)
)

# Plot the tree's partition of the (x1, x2) plane. Layers, bottom to top:
# leaf rectangles, training points coloured by fitted value, and one label
# per fitted value.
ggplot() +
  geom_rect(
    mapping = aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax),
    data = rect_info,
    fill = "white",
    colour = "grey50"
  ) +
  geom_point(
    mapping = aes(x = x1, y = x2, color = fitted.values),
    data = dtrain
  ) +
  geom_label(
    mapping = aes(x = x, y = y, label = fitted.values),
    data = label_points
  ) +
  labs(color = "Valor ajustado") +
  theme_light()